import torch
from enum import Enum, auto
from dataset.sim_1d_no_x import Sim1d_noX
from dataset.ticket import Data
from utils.data_class import PVTrainDataSet
from torch.utils.data import DataLoader, Dataset


class DataloaderMode(Enum):
    """Which data split a loader is being built for.

    Consumed by ``create_dataloader`` to pick shuffling, batch size and
    worker count. Member values come from ``auto()`` (1, 2, 3 in
    declaration order).
    """

    # Training split.
    train = auto()
    # Held-out test split.
    test = auto()
    # Validation split.
    val = auto()


def create_dataloader(cfg, mode, dataset):
    """Build a ``torch.utils.data.DataLoader`` for the requested split.

    Args:
        cfg: Configuration object. Train mode reads ``cfg.network.batch_size``
            and ``cfg.network.num_workers``. Val mode reads
            ``cfg.network.batch_size`` and ``cfg.test.num_workers``. Test mode
            reads ``cfg.test.batch_size`` and ``cfg.test.num_workers``.
        mode: A ``DataloaderMode`` member selecting the split.
        dataset: Data to load. In train/val mode it is handed to the
            ``DataLoader`` unchanged. In test mode, anything that is not
            already a ``torch.utils.data.Dataset`` is wrapped in ``Dataset_``
            (which reads per-index samples from its ``treatment``,
            ``treatment_proxy``, ``outcome_proxy``, ``outcome`` and
            ``backdoor`` attributes).

    Returns:
        A configured ``DataLoader``. Only train mode shuffles. Every mode
        uses ``pin_memory=True``, and test mode keeps the last partial batch.

    Raises:
        ValueError: If ``mode`` is not one of the ``DataloaderMode`` members.
    """
    if mode is DataloaderMode.train:
        return DataLoader(
            dataset=dataset,
            batch_size=cfg.network.batch_size,
            shuffle=True,
            num_workers=cfg.network.num_workers,
            pin_memory=True,
        )

    if mode is DataloaderMode.val:
        return DataLoader(
            dataset=dataset,
            batch_size=cfg.network.batch_size,
            shuffle=False,
            num_workers=cfg.test.num_workers,
            pin_memory=True,
        )

    if mode is DataloaderMode.test:
        # Test data presumably arrives as a column container (one attribute
        # per variable) rather than a per-sample Dataset. Wrap it so the loader
        # indexes samples, not columns. This wrapping must use the
        # caller-supplied ``dataset``: ``Dataset_`` takes a single argument.
        if not isinstance(dataset, Dataset):
            dataset = Dataset_(dataset)
        return DataLoader(
            dataset=dataset,
            batch_size=cfg.test.batch_size,
            shuffle=False,
            num_workers=cfg.test.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    raise ValueError(f"invalid dataloader mode {mode}")


class Dataset_(Dataset):
    """Map-style dataset that reads one sample per index from column attributes.

    The wrapped object must expose indexable ``treatment``, ``treatment_proxy``,
    ``outcome_proxy`` and ``outcome`` attributes, plus ``backdoor``, which may
    be ``None``. Only ``treatment`` determines the length; the other columns
    are assumed to be at least as long (not checked here).
    """

    def __init__(self, dataset):
        # Column container that samples are read from on each access.
        self.dataset = dataset

    def __len__(self):
        # Number of samples = number of entries in the treatment column.
        return len(self.dataset.treatment)

    def __getitem__(self, idx):
        """Return ``(treatment, treatment_proxy, outcome_proxy, outcome, backdoor)`` at ``idx``.

        If the source has no backdoor column (``backdoor is None``), an empty
        tensor takes its place so every sample keeps the same 5-tuple layout.
        """
        src = self.dataset
        columns = (src.treatment, src.treatment_proxy, src.outcome_proxy, src.outcome)
        picked = tuple(column[idx] for column in columns)

        backdoor = src.backdoor
        if backdoor is None:
            backdoor_item = torch.tensor([])
        else:
            backdoor_item = backdoor[idx]

        return picked + (backdoor_item,)
    

class ExtendedDataset(Dataset):
    """Wrap a dataset and append one extra value to every sample.

    ``base_dataset[idx]`` must be iterable (e.g. a tuple of fields). The
    returned sample is those fields followed by ``new_data_column[idx]``.
    """

    def __init__(self, base_dataset, new_data_column):
        # Source of the original per-index samples.
        self.base_dataset = base_dataset
        # Per-index values appended as the final field of each sample.
        self.new_data_column = new_data_column

    def __len__(self):
        # Length follows the wrapped dataset. ``new_data_column`` is assumed
        # to hold at least this many entries; that is not verified here.
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # The base sample is fetched before the extra value, then the two are
        # concatenated into one flat tuple.
        return tuple(self.base_dataset[idx]) + (self.new_data_column[idx],)
